#!/usr/bin/env python
################################################################
#
# plotapot:
#   plot an analytic potfit potential using gnuplot
#
################################################################
#
#   Copyright 2013
#             Institute for Theoretical and Applied Physics
#             University of Stuttgart, D-70550 Stuttgart, Germany
#             https://www.potfit.net/
#
#################################################################
#
#   This file is part of potfit.
#
#   potfit is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   potfit is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with potfit; if not, see <http://www.gnu.org/licenses/>.
#
#################################################################

import argparse
import functions
import math
import os
import subprocess
import sys

def get_num_atom_types(interaction, cols):
    """Return the number of atom types implied by a column count.

    Each supported interaction type has a closed-form expression that is
    evaluated for cols and truncated to an integer:

        pair    -- solves n*(n+1)/2 == cols
        eam     -- solves n*(n+1)/2 + 2*n == cols
        adp     -- solves 3*n*(n+1)/2 + 2*n == cols
        meam    -- solves n*n + 4*n == cols
        stiweb  -- sqrt(cols - 0.5) - 0.5
        tersoff -- sqrt(cols)

    interaction -- lower-case interaction name
    cols        -- number of potential columns

    Any other interaction name yields 0.
    """
    solvers = {
        'pair': lambda c: 0.5 * (math.sqrt(1 + 8 * c) - 1),
        'eam': lambda c: 0.5 * (math.sqrt(25 + 8 * c) - 5),
        'adp': lambda c: (math.sqrt(49 + 24 * c) - 7) / 6,
        'meam': lambda c: math.sqrt(4 + c) - 2,
        'stiweb': lambda c: math.sqrt(c - 0.5) - 0.5,
        'tersoff': lambda c: math.sqrt(c),
    }
    solve = solvers.get(interaction)
    if solve is None:
        return 0
    # only the selected formula is evaluated, then truncated towards zero
    return int(solve(cols))

def _chunks_of_three(total):
    """Split total curves into groups of three plus one trailing remainder group."""
    full, rest = divmod(total, 3)
    chunks = [3] * full
    if rest:
        chunks.append(rest)
    return chunks

def make_partition(interaction, n):
    """Return the subplot layout for every potential group of an interaction.

    The result is a list with one entry per potential group; each entry is a
    list holding the number of curves per subplot (at most three curves
    each).  Two layouts are used as building blocks:

        pair     -- n*(n+1)/2 curves, one per unordered pair of atom types
        transfer -- n curves, one per atom type

    interaction -- lower-case interaction name
    n           -- number of atom types

    Returns None for 'stiweb' and 'tersoff' (no layout is defined for them)
    and 0 for any other unknown interaction name.  Entries that use the same
    layout refer to the same list object.
    """
    # Integer arithmetic throughout: true division would hand range() a float
    # on Python 3 and raise TypeError.
    n = int(n)
    pair = _chunks_of_three(n * (n + 1) // 2)
    transfer = _chunks_of_three(n)

    if interaction == 'pair':
        return [pair]
    elif interaction == 'eam':
        return [pair, transfer, transfer]
    elif interaction == 'adp':
        return [pair, transfer, transfer, pair, pair]
    elif interaction == 'meam':
        return [pair, transfer, transfer, pair, transfer]
    elif interaction in ('stiweb', 'tersoff'):
        return None
    else:
        return 0

def write_gnuplot_header(f, index, prefix, xra = (0, 6), yra = 'auto'):
    """Write the gnuplot preamble for one group of potentials to f.

    Reads the module-level globals partition, args and size.

    f      -- writable file object receiving the gnuplot commands
    index  -- position in the global partition list; the number of subplots
              in partition[index] scales the terminal width
    prefix -- used in the window title and in the PNG file name
              ('plot_<prefix>.png')
    xra    -- two-element sequence (xmin, xmax) for 'set xrange'
    yra    -- two-element sequence (ymin, ymax) for 'set yrange', or the
              string 'auto' to emit no yrange command
    """
    # NOTE: the default for xra is an immutable tuple so that it can never be
    # modified by accident and shared between calls.
    f.write('reset;\n')
    f.write('set termoption solid;\n')
    width = len(partition[index])
    if args.t and not args.p:
        # interactive terminal: one window, widened by the number of subplots
        f.write('set term %s title \'%s\' size %s,%s\n' % (args.t, prefix + ' potentials', width*size[0], size[1]))
    else:
        # PNG output requested (-p) or no terminal given
        f.write('set term pngcairo;\n')
        f.write('set output \'plot_%s.png\';\n' % prefix)
    f.write('set samples 1000;\n')
    f.write('set grid;\n')
    f.write('set xrange [{}:{}];\n'.format(xra[0],xra[1]))
    if yra != 'auto':
        f.write('set yrange [{}:{}];\n'.format(yra[0],yra[1]))
    # cutoff function referenced by the plot commands of smoothed potentials
    f.write('cof(x) = x**4/(1+x**4);\n')

def _write_plot_command(f, index, first, num, offset, title_string):
    """Write one gnuplot 'pl' command covering potentials first .. first+num-1."""
    f.write('pl ')
    last = first + num - 1
    for k in range(first, first + num):
        pot = pot_table[k]
        # position inside this plot command (line colour) and inside the
        # whole potential group (title / data-set index)
        colour = k + 1 - first
        group_pos = k - offset
        if pot.do_smooth != 0:
            # smoothed potentials are multiplied by cof((x - cutoff)/h),
            # with h taken from the value of the last parameter
            f.write('cof((x-{0})/{1})*('.format(pot.cutoff, pot.params[-1][1]))
        f.write(func_strings[k])
        if pot.do_smooth != 0:
            f.write(')')
        f.write(' w l')
        f.write(' t \'%s\'' % title_string[group_pos])
        f.write(' lt 0 lw 2 lc %s' % colour)
        if index == 0 and args.pair_dist_file:
            # overlay the matching data set of the pair distribution file
            f.write(', \'%s\' i %s axes x1y2 w steps lw 1.2 lc %s t \'g(r) %s\'' %
                    (args.pair_dist_file, group_pos, colour, title_string[group_pos]))
        if k < last:
            f.write(', \\\n\t')
    f.write(';\n')

def write_gnuplot_data(f, index, offset, title_string):
    """Write the gnuplot plot commands for one group of potentials to f.

    Reads the module-level globals partition, pot_table, func_strings and
    args.

    f            -- writable file object receiving the gnuplot commands
    index        -- position of the group in the global partition list
    offset       -- index in pot_table / func_strings of the first potential
                    belonging to this group
    title_string -- list of curve titles for this group, indexed relative to
                    the start of the group

    If partition[index] holds more than one entry, a multiplot with
    equal-width panels side by side is written, one panel per entry;
    otherwise a single plot command is written.
    """
    groups = partition[index]
    # offsets computed with true division arrive as floats on Python 3
    offset = int(offset)
    first = offset
    if len(groups) > 1:
        f.write('set multiplot;\n')
        fraction = float(1.0/len(groups))
        for panel, num in enumerate(groups):
            f.write('set origin {},0;\n'.format(panel*fraction))
            f.write('set size {},1;\n'.format(fraction))
            # Titles are looked up relative to the start of the group
            # (k - offset), exactly as in the single-plot case; using the
            # absolute index would run past the end of title_string for
            # every group with a non-zero offset.
            _write_plot_command(f, index, first, num, offset, title_string)
            first += num
    else:
        _write_plot_command(f, index, first, groups[0], offset, title_string)

# command line interface: four options, the potential file and an
# optional pair distribution file
parser = argparse.ArgumentParser(
    description="Plot an analytic potfit potential using gnuplot.")
parser.add_argument(
    '-f', action='store_true',
    help="write plotfiles for gnuplot and exit")
parser.add_argument(
    '-p', action='store_true',
    help="only create plots in png format but do not show them")
parser.add_argument(
    '-s', metavar='size', type=str, default='640x480',
    help="window size of the gnuplot window (default: 640x480)")
parser.add_argument(
    '-t', metavar='terminal', type=str, default='qt',
    help="use this gnuplot terminal (see the 'help terminal' command in gnuplot for details)")
parser.add_argument(
    'pot_file', type=str,
    help="potfit potential file")
parser.add_argument(
    'pair_dist_file', type=str, nargs='?',
    help="pair distribution file")
args = parser.parse_args()

# Open the potential file; the handle f is read and closed further below.
try:
    f = open(args.pot_file)
except (IOError, OSError):
    sys.stderr.write("Could not open potential file '%s'.\n" % args.pot_file)
    sys.exit(1)

# Parse the '-s WIDTHxHEIGHT' option into a list of two integers.
# The format check sits outside the try block: sys.exit() raises SystemExit,
# which a surrounding catch-all handler would intercept and report twice.
size = args.s.split('x')
if len(size) != 2:
    sys.stderr.write('Error: Could not read -s parameter\n')
    sys.stderr.write('\tTry -s 800x600\n')
    f.close()
    sys.exit(1)
try:
    size = [int(s) for s in size]
except ValueError:
    sys.stderr.write('Error: Could not read -s parameter\n')
    f.close()
    sys.exit(1)

num_global = 0
globals_table = dict()
pot_table = []
atom_types = []
# filled in by the '#T' and '#F' header lines; verified after the loop
interaction = None
cols = None

# read data from potential file
line = f.readline()
while line != '':
    data = line.split()
    if line[0] == '#':
        # header: the character following '#' selects the record type
        # (slicing keeps a bare '#' without newline from raising IndexError)
        tag = line[1:2]
        if tag == 'F':
            # '#F <format> <columns>'; only format 0 is accepted
            if data[1] != '0':
                sys.stderr.write('This tool is designed for analytic potentials!\n')
                sys.stderr.write('Your potential file format is wrong.\n')
                sys.exit(1)
            cols = int(data[2])
        if tag == 'T':
            # '#T <interaction>'
            interaction = data[1].lower()
            if interaction in ['stiweb', 'tersoff']:
                sys.stderr.write('Error: Currently %s potentials are not supported\n' % interaction)
                sys.exit(1)
        if tag == 'C':
            # '#C <name> <name> ...' -- names used for the curve titles
            atom_types = data[1:]
    else:
        if len(data) > 0 and data[0] == 'global':
            # 'global <count>' followed by <count> lines of '<name> <value>'
            num_global = int(data[1])
            for i in range(num_global):
                entry = f.readline().split()
                globals_table[entry[0]] = float(entry[1])
        if len(data) > 0 and data[0] == 'type':
            name = data[1]
            do_smooth = 0
            rmin = 0.0
            # a '_sc' suffix requests the smooth cutoff variant
            if name.endswith('_sc'):
                do_smooth = 1
                name = name[:-3]
            # 'pohlong' is handled by the 'bjs' function class
            if name == 'pohlong':
                name = 'bjs'
            try:
                pot_class = getattr(functions, name)
            except AttributeError:
                sys.stderr.write("Could not find the function type >> %s <<\n\n" % name)
                sys.exit(1)
            pot_table.append(pot_class())
            # skip ahead to the 'cutoff' line, tolerating blank lines
            while line != '' and line.split()[:1] != ['cutoff']:
                line = f.readline()
            if line == '':
                sys.stderr.write("Error: No cutoff found for potential type '%s'\n" % name)
                sys.exit(1)
            cutoff = float(line.split()[1])
            line = f.readline()
            # comment lines may follow; '# rmin <value>' sets rmin
            while line.startswith('#'):
                fields = line.split()
                if len(fields) > 2 and fields[1] == 'rmin':
                    rmin = float(fields[2])
                line = f.readline()
            pot_table[-1].set_properties(cutoff, do_smooth, False, rmin)
            num_params = pot_table[-1].get_num_params()
            params = []
            for i in range(num_params):
                param_data = line.split()
                if param_data[0].endswith('!'):
                    # a trailing '!' refers to an entry of the global table
                    params.append([param_data[0][:-1],float(globals_table[param_data[0][:-1]])])
                else:
                    params.append([param_data[0],float(param_data[1])])
                line = f.readline()
            pot_table[-1].set_params(params)
            # line already holds the first unread line after the parameters;
            # process it instead of discarding it, so that a 'type' line
            # directly following a parameter block is not lost
            continue
    line = f.readline()
f.close()

if interaction is None or cols is None:
    sys.stderr.write('Error: The potential file header is incomplete (#F and #T lines are required)\n')
    sys.exit(1)

num_atom_types = get_num_atom_types(interaction,cols)
# number of pair columns; integer division keeps the value usable as a
# range()/index offset (true division yields a float on Python 3)
pair_col = num_atom_types * (num_atom_types + 1) // 2

# gnuplot expression of every potential with its parameter values filled in
func_strings = []
for pot in pot_table:
    par_data = [param[1] for param in pot.params]
    func_strings.append(pot.get_function_string().format(*par_data))

partition = make_partition(interaction, num_atom_types)
if not partition:
    # make_partition defines no layout for this interaction type
    sys.stderr.write("Error: Interaction type '%s' is not supported\n" % interaction)
    sys.exit(1)

# calculate cutoff
# per potential group: the largest cutoff and 0.8 times the smallest rmin,
# used as the x range of the corresponding plot
cutoff = [0]*len(partition)
rmin = [99]*len(partition)
count = 0
for i, group in enumerate(partition):
    group_size = sum(group)
    for j in range(group_size):
        cutoff[i] = max(cutoff[i],pot_table[count + j].cutoff)
        rmin[i] = min(rmin[i],0.8 * pot_table[count + j].rmin)
    count += group_size

# create title strings
# pair titles are 'A-B' for every unordered pair, transfer titles are 'A';
# numeric indices are used when the file names no atom types
title_pair = []
title_transfer = []
for i in range(num_atom_types):
    for j in range(i, num_atom_types):
        if atom_types:
            title_pair.append(atom_types[i]+'-'+atom_types[j])
        else:
            title_pair.append(str(i)+'-'+str(j))
    if atom_types:
        title_transfer.append(atom_types[i])
    else:
        title_transfer.append(str(i))

# Every plot file is written inside a 'with' block so that it is closed
# even if one of the writer functions raises.

# write pair potentials
with open('plot_pair', 'w') as f:
    write_gnuplot_header(f, 0, 'pair', [rmin[0],cutoff[0]], [-.3,.5])
    write_gnuplot_data(f, 0, 0, title_pair)

# write eam potentials
if interaction in ['eam', 'adp', 'meam']:
    with open('plot_transfer', 'w') as f:
        write_gnuplot_header(f, 1, 'transfer', [rmin[1],cutoff[1]], [-.3,.5])
        write_gnuplot_data(f, 1, pair_col, title_transfer)

    with open('plot_embedding', 'w') as f:
        write_gnuplot_header(f, 2, 'embedding', [rmin[2],cutoff[2]])
        write_gnuplot_data(f, 2, pair_col + num_atom_types, title_transfer)

# write adp potentials
if interaction == 'adp':
    with open('plot_upot', 'w') as f:
        write_gnuplot_header(f, 3, 'dipole', [rmin[3],cutoff[3]], [-.3,.5])
        write_gnuplot_data(f, 3, pair_col + 2 * num_atom_types, title_pair)

    with open('plot_wpot', 'w') as f:
        write_gnuplot_header(f, 4, 'quadrupole', [rmin[4],cutoff[4]], [-.3,.5])
        write_gnuplot_data(f, 4, 2 * pair_col + 2 * num_atom_types, title_pair)

# write meam potentials
if interaction == 'meam':
    with open('plot_f', 'w') as f:
        write_gnuplot_header(f, 3, 'f', [rmin[3],cutoff[3]], [-.3,.5])
        write_gnuplot_data(f, 3, pair_col + 2 * num_atom_types, title_pair)

    with open('plot_g', 'w') as f:
        write_gnuplot_header(f, 4, 'g', [-1.1,1.1])
        # group 4 of a meam partition uses the per-atom-type layout (one
        # curve per atom type), so it is labelled with the per-type titles
        # rather than the pair titles
        write_gnuplot_data(f, 4, 2 * pair_col + 2 * num_atom_types, title_transfer)

# plot files that exist for this interaction type, in the order written
plot_files = ['plot_pair']
if interaction in ['eam', 'adp', 'meam']:
    plot_files += ['plot_transfer', 'plot_embedding']
if interaction == 'adp':
    plot_files += ['plot_upot', 'plot_wpot']
elif interaction == 'meam':
    plot_files += ['plot_f', 'plot_g']

# unless only the plot files were requested (-f), run gnuplot on each one
if not args.f:
    for plot_file in plot_files:
        subprocess.call(["gnuplot", "-persist", plot_file])

# in png mode (-p) the plot files are removed again
if args.p:
    for plot_file in plot_files:
        os.remove(plot_file)
